{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 在 ONNX IR 中的张量表示\n",
    "\n",
    "参考：[`ir.TensorProtocol`](https://onnxscript.ai/intermediate_representation/tensors.html)\n",
    "\n",
    "ONNX IR提供了 [`ir.TensorProtocol`](https://onnxscript.ai/intermediate_representation/ir_api.html#onnxscript.ir.TensorProtocol) 接口，以便使用不同的数据结构作为张量的后备数据。除了传统的 `onnx.TensorProto` 之外，您还可以使用 `np.ndarray`、`torch.Tensor`、`jax.Array` 以及几乎任何其他东西来表示计算图中的张量。这使得它们可以通过相同的 `TensorProtocol` 接口进行访问和序列化，而在初始化期间不会发生额外的复制。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `ir.TensorProtocol`\n",
    "\n",
    "[`ir.TensorProtocol`](https://microsoft.github.io/onnxscript/intermediate_representation/ir_api.html#onnxscript.ir.TensorProtocol) 定义了一个只读接口，用于表示张量。实现该接口的张量类具有 `name`、 `shape`、 `dtype`、 `size`、 `nbytes` 和 `metadata_props` 等属性，用于描述张量的基本属性。此外，它还应实现两个方法 `numpy` 和 `__array__` ，这两个方法将从底层数据生成等效的 NumPy 数组。\n",
    "\n",
    "```{note}\n",
    "当与初始化器、常量值和张量属性交互时，最好假设使用 `ir.TensorProtocol`，只有在需要检查具体类时才使用 {func}`isinstance`。\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 张量类\n",
    "\n",
    "### `ir.TensorProtoTensor`\n",
    "\n",
    "使用 [`ir.TensorProtoTensor`](https://microsoft.github.io/onnxscript/intermediate_representation/ir_api.html#onnxscript.ir.TensorProtoTensor) 作为对 proto 的包装以实现 `ir.TensorProtocol` 接口。您可以像往常一样访问 `shape`、 `dtype` 等。只有在调用 `numpy()` 时才会产生副本。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "直接初始化 `ir.TensorProtoTensor`，如下所示，是可能的。然而，通常建议使用 `ir.serde.deserialize_tensor`，因为它可以处理所有类型的 `TensorProto` （例如，`ir.TensorProtoTensor` 不处理外部张量）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor:  TensorProtoTensor<INT16,[3]>(name='tensor')\n",
      "shape:  [3]\n",
      "dtype:  INT16\n",
      "True\n",
      "tobytes:  b'\\x01\\x00\\x02\\x00\\x03\\x00'\n",
      "numpy:  [1 2 3]\n"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "from onnxscript import ir\n",
    "\n",
    "tensor_proto = onnx.helper.make_tensor(\"tensor\", onnx.TensorProto.INT16, (3,), [1, 2, 3])\n",
    "tensor = ir.TensorProtoTensor(tensor_proto)\n",
    "print(\"tensor: \", tensor)  # TensorProtoTensor<INT16,[3]>(name='tensor')\n",
    "print(\"shape: \", tensor.shape)  # ir.Shape([3])\n",
    "print(\"dtype: \", tensor.dtype)  # ir.DataType.INT16\n",
    "print(tensor.raw == tensor_proto)  # The raw field is the exact tensor_proto provided at initialization\n",
    "print(\"tobytes: \", tensor.tobytes())  # b'\\x01\\x00\\x02\\x00\\x03\\x00'\n",
    "print(\"numpy: \", tensor.numpy())  # array([1, 2, 3], dtype=int16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `ir.ExternalTensor`\n",
    "\n",
    "存储在外部磁盘上的张量数据通常很大，加载时将占用内存。[`ir.ExternalTensor`](https://microsoft.github.io/onnxscript/intermediate_representation/ir_api.html#onnxscript.ir.ExternalTensor) 类使用内存映射来避免将张量加载到内存中。您可以使用张量作为普通的 NumPy 数组，内存使用量最小。\n",
    "\n",
    "请参阅 `ir.serde.deserialize_tensor` 以找到将 `onnx.TensorProto` 转换为 `ir.ExternalTensor` 的示例。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `ir.Tensor`\n",
    "\n",
    "`ir.Tensor` 是围绕 NumPy 数组兼容的数组对象（如 `np.ndarray` 和 `torch.Tensor`）的包装器。它最适合创建内存中的张量，而不将其转换为 `TensorProto`，以减少转换开销。\n",
    "\n",
    "```{note}\n",
    "如果一个数组对象定义了 `__array__` 方法，则它是兼容的。\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从数组创建张量，只需用 NumPy 数组初始化即可"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Tensor<DOUBLE,[1,2]>(array([[0.71395225, 0.48701339]]), name=None)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "tensor = ir.Tensor(np.random.rand(1, 2))\n",
    "tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "初始化器将从数组中获取数据类型和形状信息。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果要从 NumPy 数组以外的对象创建张量，您需要指定数据类型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1. 2. 3.]\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from onnxscript import ir\n",
    "\n",
    "torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16)\n",
    "tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16)\n",
    "print(tensor.numpy())  # array([1., 2., 3.], dtype=float16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 字符串张量\n",
    "\n",
    "使用 [`ir.StringTensor`](https://microsoft.github.io/onnxscript/intermediate_representation/ir_api.html#onnxscript.ir.StringTensor) 创建字符串张量。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 回退 `TensorProto`\n",
    "\n",
    "在以下场景中，展示了如何从 `TensorProto` 转换到 `ir.Tensor`，运行一些计算，然后将其转换回 `ir.Tensor`，最后 `TensorProto`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor: TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')\n",
      "tensor.numpy(): [[1. 2. 3.]\n",
      " [4. 5. 6.]]\n",
      "tensor.tobytes(): b'\\x00<\\x00@\\x00B\\x00D\\x00E\\x00F'\n",
      "mean: [2.5 3.5 4.5]\n",
      "tensor_mean: Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name=None)\n",
      "mean_tensor_proto: dims: 3\n",
      "data_type: 10\n",
      "raw_data: \"\\000A\\000C\\200D\"\n",
      "\n",
      "onnx.numpy_helper.to_array(mean_tensor_proto): [2.5 3.5 4.5]\n",
      "tensor_mean.tobytes(): b'\\x00A\\x00C\\x80D'\n",
      "Bytes same as proto: True\n",
      "\n",
      "# Explore other methods defined by TensorProtocol:\n",
      "tensor_mean.shape: [3]\n",
      "tensor_mean.dtype: FLOAT16\n",
      "tensor_mean.name: None\n",
      "tensor_mean.doc_string: None\n",
      "tensor_mean.raw: [2.5 3.5 4.5]\n",
      "tensor_mean.metadata_props: {}\n",
      "tensor_mean.size: 3\n",
      "tensor_mean.nbytes: 6\n",
      "tensor_mean.raw: [2.5 3.5 4.5]\n"
     ]
    }
   ],
   "source": [
    "from onnxscript import ir\n",
    "import onnx\n",
    "import numpy as np\n",
    "\n",
    "# 1. Create the TensorProto\n",
    "proto = onnx.helper.make_tensor(\n",
    "    \"tensor\", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6]\n",
    ")\n",
    "\n",
    "# 2. Create an IR Tensor from the Protobuf message\n",
    "tensor = ir.serde.deserialize_tensor(proto)\n",
    "# Note that we get a TensorProtoTensor that implements the TensorProtocol\n",
    "print(\"tensor:\", tensor)  # TensorProtoTensor<FLOAT16,[2,3]>(name='tensor')\n",
    "print(\"tensor.numpy():\", tensor.numpy())   # [[1. 2. 3.]\n",
    "                                           #  [4. 5. 6.]]\n",
    "print(\"tensor.tobytes():\", tensor.tobytes())  # b'\\x00<\\x00@\\x00B\\x00D\\x00E\\x00F'\n",
    "\n",
    "# 3. Do computation using numpy\n",
    "mean = tensor.numpy().mean(axis=0)\n",
    "print(\"mean:\", mean)  # array([2.5, 3.5, 4.5], dtype=float16)\n",
    "\n",
    "# 4. Create a Tensor from the ndarray. Note that we use ir.Tensor\n",
    "tensor_mean = ir.Tensor(mean)\n",
    "print(\"tensor_mean:\", tensor_mean)  # Tensor<FLOAT16,[3]>(array([2.5, 3.5, 4.5], dtype=float16), name='')\n",
    "\n",
    "# 5. Obtain the TensorProto from ir.Tensor\n",
    "mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean)\n",
    "print(\"mean_tensor_proto:\", mean_tensor_proto)\n",
    "print(\n",
    "    \"onnx.numpy_helper.to_array(mean_tensor_proto):\",\n",
    "    onnx.numpy_helper.to_array(mean_tensor_proto)\n",
    "    # array([2.5, 3.5, 4.5], dtype=float16)\n",
    ")\n",
    "\n",
    "# You can obtain the bytes data as well\n",
    "print(\"tensor_mean.tobytes():\", tensor_mean.tobytes())\n",
    "print(\"Bytes same as proto:\", mean_tensor_proto.raw_data == tensor_mean.tobytes())\n",
    "\n",
    "# Explore other methods defined by TensorProtocol:\n",
    "print(\"\\n# Explore other methods defined by TensorProtocol:\")\n",
    "print(\"tensor_mean.shape:\", tensor_mean.shape)\n",
    "print(\"tensor_mean.dtype:\", tensor_mean.dtype)\n",
    "print(\"tensor_mean.name:\", tensor_mean.name)\n",
    "print(\"tensor_mean.doc_string:\", tensor_mean.doc_string)\n",
    "print(\"tensor_mean.raw:\", tensor_mean.raw)\n",
    "print(\"tensor_mean.metadata_props:\", tensor_mean.metadata_props)\n",
    "print(\"tensor_mean.size:\", tensor_mean.size)\n",
    "print(\"tensor_mean.nbytes:\", tensor_mean.nbytes)\n",
    "print(\"tensor_mean.raw:\", tensor_mean.raw)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 处理非原生 NumPy 数据类型：`bfloat16`、`float8`、`int4`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`ir.Tensor.numpy()` 生成张量值的 NumPy 数组表示。当张量的数据类型为 `BFLOAT16` 、 `FLOAT8[...]` 或 `[U]INT4` 时，这些类型不支持 NumPy，将使用 `ml_dtypes` 包中的数据类型。\n",
    "\n",
    "`uint4` / `int4` 总是解包；`tobyte()` 生成预期的打包表示。\n",
    "\n",
    "`ir.Tensor` 的初始化需要 NumPy 数组遵循以下类型约束，或者具有 `ml_dtypes` 数据类型。\n",
    "- `int8` 用于（未打包的） `int4`，符号位扩展到 8 位。\n",
    "- `uint8` 用于（解包的）`uint4`。\n",
    "- `uint8` 用于 8 位数据类型，如 `float8`。\n",
    "- `uint16` 用于 `bfloat16`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以下示例展示了如何创建 `FLOAT8E4M3FN` 张量，变换其值，并创建新的张量来存储变换后的值。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)\n",
      "tensor.numpy(): [0.00195312 0.00585938]\n",
      "times_100: [0.1875 0.5625]\n",
      "new_tensor: Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)\n",
      "new_tensor == times_100 [ True  True]\n"
     ]
    }
   ],
   "source": [
    "from onnxscript import ir\n",
    "import numpy as np\n",
    "\n",
    "array = np.array([0b1, 0b11], dtype=np.uint8)\n",
    "# The array is reinterpreted using the ml_dtypes package\n",
    "tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN)\n",
    "print(tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None)\n",
    "print(\"tensor.numpy():\", tensor.numpy())  # [0.00195312 0.00585938]\n",
    "\n",
    "# Compute\n",
    "times_100 = tensor.numpy() * 100\n",
    "print(\"times_100:\", times_100)\n",
    "\n",
    "# Create a new tensor out of the new value; dtype must be specified\n",
    "new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN)\n",
    "# You can also directly create the tensor from the float8 array without specifying dtype\n",
    "# new_tensor = ir.Tensor(times_100)\n",
    "print(\"new_tensor:\", new_tensor)  # Tensor<FLOAT8E4M3FN,[2]>(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None)\n",
    "print(\"new_tensor == times_100\", new_tensor.numpy() == times_100)  # array([ True,  True])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 高级用法\n",
    "\n",
    "### 子类化 `ir.Tensor`以实现更高效的访问和更广泛的支持 `dtype`\n",
    "\n",
    "`ir.Tensor` 内部将任何与数组兼容的对象转换为 NumPy 数组，以生成 `tobytes()` 中的字节表示。由于额外的转换，这可能会降低效率。它还限制了对于 NumPy 不支持的数据类型（如 `bfloat16`）的支持，因为 `__array__` 方法将失败。\n",
    "\n",
    "为了完全支持来自其他框架的数组，通常创建专门的类来处理它们是个好主意。下面的 `TorchTensor` 类演示了您如何通过子类化 `ir.Tensor` 来处理 PyTorch 张量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor:  TorchTensor<BFLOAT16,[3]>(tensor([1., 2., 3.], dtype=torch.bfloat16), name=None)\n",
      "numpy:  [16256 16384 16448]\n",
      "tobytes:  b'\\x80?\\x00@@@'\n",
      "nbytes:  6\n"
     ]
    }
   ],
   "source": [
    "import ctypes\n",
    "from typing import Any\n",
    "\n",
    "import torch\n",
    "from onnxscript import ir\n",
    "\n",
    "# Define utilities to convert PyTorch data types so users do not need to specify manually\n",
    "_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {\n",
    "    torch.bfloat16: ir.DataType.BFLOAT16,\n",
    "    torch.bool: ir.DataType.BOOL,\n",
    "    torch.complex128: ir.DataType.COMPLEX128,\n",
    "    torch.complex64: ir.DataType.COMPLEX64,\n",
    "    torch.float16: ir.DataType.FLOAT16,\n",
    "    torch.float32: ir.DataType.FLOAT,\n",
    "    torch.float64: ir.DataType.DOUBLE,\n",
    "    torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN,\n",
    "    torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,\n",
    "    torch.float8_e5m2: ir.DataType.FLOAT8E5M2,\n",
    "    torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,\n",
    "    torch.int16: ir.DataType.INT16,\n",
    "    torch.int32: ir.DataType.INT32,\n",
    "    torch.int64: ir.DataType.INT64,\n",
    "    torch.int8: ir.DataType.INT8,\n",
    "    torch.uint8: ir.DataType.UINT8,\n",
    "}\n",
    "\n",
    "\n",
    "def _torch_dtype_to_onnx_dtype(dtype: torch.dtype) -> ir.DataType:\n",
    "    return _TORCH_DTYPE_TO_ONNX[dtype]\n",
    "\n",
    "class TorchTensor(ir.Tensor):\n",
    "    def __init__(self, tensor: torch.Tensor):\n",
    "        # Pass the tensor as the raw data to ir.Tensor's constructor\n",
    "        super().__init__(tensor, dtype=_torch_dtype_to_onnx_dtype(tensor.dtype))\n",
    "\n",
    "    def __array__(self, dtype: Any = None) -> \"np.ndarray\":\n",
    "        # numpy() calls __array__ in ir.Tensor\n",
    "        if self.dtype == ir.DataType.BFLOAT16:\n",
    "            return self.raw.view(torch.uint16).__array__(dtype)\n",
    "        if self.dtype in {\n",
    "            ir.DataType.FLOAT8E4M3FN,\n",
    "            ir.DataType.FLOAT8E4M3FNUZ,\n",
    "            ir.DataType.FLOAT8E5M2,\n",
    "            ir.DataType.FLOAT8E5M2FNUZ\n",
    "        }:\n",
    "            return self.raw.view(torch.uint8).__array__(dtype)\n",
    "        return self.raw.__array__(dtype)\n",
    "\n",
    "    def tobytes(self) -> bytes:\n",
    "        # Implement tobytes to support native PyTorch types so we can use types like bloat16\n",
    "        # Reading from memory directly is also more efficient because\n",
    "        # it avoids copying to a NumPy array\n",
    "        tensor = self.raw.detach().cpu().contiguous()\n",
    "        return bytes(\n",
    "            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(\n",
    "                tensor.data_ptr()\n",
    "            )\n",
    "        )\n",
    "\n",
    "# Test the implementation\n",
    "torch_tensor = torch.tensor([1,2,3], dtype=torch.bfloat16)\n",
    "tensor = TorchTensor(torch_tensor)\n",
    "print(\"tensor: \", tensor)\n",
    "print(\"numpy: \", tensor.numpy())\n",
    "print(\"tobytes: \", tensor.tobytes())  # b'\\x80?\\x00@@@'\n",
    "print(\"nbytes: \", tensor.nbytes)  # 6"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "该类实现了 `tobytes()` ，以生成张量在序列化为 ONNX 文件/TensorProto 时的正确字节表示。该类还实现了 `__array__()` 方法，以返回 NumPy 不支持的数据类型的位表示。这样，分析阶段仍然可以对这些值进行计算。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用不同框架进行计算\n",
    "\n",
    "由于 `ir.Tensor` 实现了 `__array__` 方法和 `__dlpack__` 方法，其内容可以在不复制的情况下与计算框架共享。例如："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[42. 84.]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[300. 800.]\n",
      "[40. 60.]\n",
      "tensor([-20., -20.])\n",
      "<class 'onnx.onnx_ml_pb2.TensorProto'>\n",
      "dims: 2\n",
      "data_type: 1\n",
      "raw_data: \"\\000\\000 A\\000\\000\\240A\"\n",
      "\n",
      "[10. 20.]\n"
     ]
    }
   ],
   "source": [
    "from onnxscript import ir\n",
    "\n",
    "# We can call numpy methods directly on ir.Tensor\n",
    "import numpy as np\n",
    "print(np.multiply(ir.Tensor(np.array([1, 2])), 42))  # array([42., 84.])\n",
    "\n",
    "# We can transfer arrays to different frameworks\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "import torch\n",
    "\n",
    "# Create ir.Tensor\n",
    "jax_array = jnp.array([10., 20.])\n",
    "ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT)\n",
    "torch_tensor = torch.tensor([30., 40.])\n",
    "ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT)\n",
    "\n",
    "# Use numpy for computation\n",
    "print(np.multiply(ir_tensor_jax, ir_tensor_torch))  # array([300., 800.], dtype=float32)\n",
    "\n",
    "# Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same\n",
    "jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch)\n",
    "print(jax_array_from_ir + jax_array)  # [40. 60.]\n",
    "\n",
    "# Use PyTorch for computation\n",
    "torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax)\n",
    "print(torch_tensor_from_ir - torch_tensor)  # tensor([-20., -20.])\n",
    "\n",
    "# They can all be serialized into TensorProto\n",
    "proto = ir.serde.serialize_tensor(ir_tensor_jax)\n",
    "print(type(proto))  # <class 'onnx.onnx_ml_pb2.TensorProto'>\n",
    "print(proto)\n",
    "\n",
    "# The value is exactly the same as jax_array\n",
    "print(ir.serde.deserialize_tensor(proto).numpy())  # [10. 20.]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这在您在图上创建需要计算具体值的传递时特别有用。您可以使用您喜欢的框架来创建传递。包含新创建的 `ir.Tensor` 的转换图将与下游传递兼容，即使它们利用其他计算框架也是如此。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
